#!/usr/bin/python
# -*- coding: utf-8 -*-

import warnings,numpy,argparse,re,sys,os,os.path,unicodedata,multiprocessing
import ocrolib

# disable rank warnings from polyfit
warnings.simplefilter('ignore',numpy.RankWarning)

# Command line interface.  The positional arguments are the ground truth
# files; the recognizer output for each one is located by swapping the
# file's extension for the one given with -x (see process1).
parser = argparse.ArgumentParser(description = """
Compute the edit distances between ground truth and recognizer output.
Run with the ground truth files as arguments, and it will find the
corresponding recognizer output files using the given extension (-x).
Missing output files are handled as empty strings, unless the -s
option is given.
""")
parser.add_argument("files",default=[],nargs='*',help="input lines")
parser.add_argument("-x","--extension",default=".txt",help="extension for recognizer output, default: %(default)s")
# parser.add_argument("-g","--gtextension",default=".gt.txt",help="extension for ground truth, default: %(default)s")
# the kinds listed here must match the ones handled in normalize()
parser.add_argument("-k","--kind",default="exact",help="kind of comparison (exact, nospace, spletdig, letdig, letters, digits, lnc), default: %(default)s")
parser.add_argument("-e","--erroronly",action="store_true",help="only output an error rate")
parser.add_argument("-s","--skipmissing",action="store_true",help="don't use missing or empty output files in the calculation")
parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count(),help="degree of parallelism passed to ocrolib.parallel_map, default: %(default)s")
args = parser.parse_args()
# expand the file arguments through ocrolib.glob_all
args.files = ocrolib.glob_all(args.files)

def levenshtein(a,b):
    """Return the Levenshtein (edit) distance between sequences a and b.

    The distance is the smallest number of single-element insertions,
    deletions and substitutions that turns one sequence into the other.
    Only two rows of the dynamic programming table are kept at a time,
    and the shorter sequence is laid along the row so that the rows are
    as short as possible.
    """
    # the distance is symmetric, so put the shorter sequence along the row
    if len(a) > len(b):
        a,b = b,a
    width = len(a)
    # row for the empty prefix of b: turning nothing into a[:j] costs j
    row = list(range(width+1))
    for i,cb in enumerate(b,1):
        above = row
        # first column: turning b[:i] into the empty prefix of a costs i
        row = [i]
        for j,ca in enumerate(a,1):
            cost = 1 if ca != cb else 0
            row.append(min(above[j]+1, row[j-1]+1, above[j-1]+cost))
    return row[width]

# Quote clean-up table, shared with the alignment code: each pair is
# (character to look for, plain-apostrophe text to put in its place).
# Double-quote variants become two apostrophes, single-quote variants
# and the accent characters become one.

replacements = [
    (u'"', u"''"),        # U+0022 quotation mark
    (u"`", u"'"),         # U+0060 grave accent
    (u"\u00b4", u"'"),    # U+00B4 acute accent
    (u"\u2018", u"'"),    # U+2018 left single quotation mark
    (u"\u2019", u"'"),    # U+2019 right single quotation mark
    (u"\u201c", u"''"),   # U+201C left double quotation mark
    (u"\u201d", u"''"),   # U+201D right double quotation mark
]

def normalize(s):
    """Canonicalize a line of text before edit distances are computed.

    The text is cleaned up (quote variants, whitespace, long dot
    leaders), decomposed with Unicode NFKD, and finally reduced to the
    character class selected by the --kind option (args.kind).

    s is expected to be a unicode string: under Python 2,
    unicodedata.normalize rejects byte strings.

    Returns the normalized string.  Raises Exception if args.kind is
    not one of exact, nospace, spletdig, letdig, letters, digits, lnc.
    """
    # underscore, tilde and hash each become two apostrophes
    s = re.sub('[_~#]',"''",s)
    s = re.sub('`',"''",s)
    s = re.sub('"',"'",s)
    # The curly quotes are spelled as unicode escapes on purpose: a
    # byte-string character class holding their UTF-8 encoding is
    # matched byte by byte against unicode text under Python 2, which
    # rewrites unrelated characters (e.g. U+00E2) and never matches
    # the quotes themselves.
    s = re.sub(u'[\u201c\u201d]',"''",s)
    # every whitespace character (newlines included) becomes a blank,
    # then leading and trailing blanks are dropped
    s = re.sub(r'\s',' ',s)
    s = re.sub(r'^\s+','',s)
    s = re.sub(r'\s+$','',s)
    # runs of four or more (possibly spaced-out) dots collapse to '....'
    s = re.sub(r'( *[.] *){4,}','....',s)
    for m,r in replacements:
        s = re.sub(m,r,s)
    s = unicodedata.normalize('NFKD',s)

    kind = args.kind
    if kind=="exact":
        return s
    if kind=="lnc":
        # letters only, compared without regard to case
        return re.sub(r'[^A-Z]','',s.upper())
    # for the remaining kinds, everything matching the pattern is deleted
    unwanted = {
        "nospace": r'\s',
        "spletdig": r'[^A-Za-z0-9 ]',
        "letdig": r'[^A-Za-z0-9]',
        "letters": r'[^A-Za-z]',
        "digits": r'[^0-9]',
    }
    if kind in unwanted:
        return re.sub(unwanted[kind],'',s)
    raise Exception("unknown normalization: "+kind)

# Without any input files there is nothing to compare; report that
# through argparse (usage message, exit status 2) instead of failing
# with an IndexError on args.files[0] below.
if not args.files:
    parser.error("no ground truth files given")

# The arguments are meant to be ground truth files; only the first name
# is checked, and a mismatch is merely a warning.
if ".gt." not in args.files[0]:
    sys.stderr.write("warning: compare on .gt.txt files, not .txt files\n")


def process1(fname):
    """Compare one ground truth file with its recognizer output.

    fname is the ground truth file.  The recognizer output is looked
    for under the same name with all extensions (as split off by
    ocrolib.allsplitext) replaced by args.extension.  Both texts are
    passed through normalize() before the edit distance is computed.

    Returns a tuple (fname, errors, ground truth length, missing):
    errors is the edit distance, and missing is the ground truth length
    if the output file does not exist, 0 otherwise.  A missing output
    file is treated as empty output.

    With -s/--skipmissing, a file whose recognizer output is missing or
    contains nothing but whitespace is left out of the statistics: the
    tuple (fname, 0, 0, 0) is returned for it, which adds nothing to
    any of the totals.
    """
    # fgt = ocrolib.allsplitext(fname)[0]+args.gtextension
    gt = normalize(ocrolib.read_text(fname))
    ftxt = ocrolib.allsplitext(fname)[0]+args.extension
    if os.path.exists(ftxt):
        raw = ocrolib.read_text(ftxt)
        empty = not raw.strip()
        txt = normalize(raw)
        missing = 0
    else:
        empty = True
        # not passed through normalize(): there is no text to clean up
        txt = ""
        missing = len(gt)
    if args.skipmissing and empty:
        return fname,0,0,0
    err = levenshtein(txt,gt)
    return fname,err,len(gt),missing

# run process1 over all ground truth files
outputs = ocrolib.parallel_map(process1,args.files,parallel=args.parallel,chunksize=10)

# Add up the per-file results; each entry is
# (file name, errors, ground truth length, missing), listed by file name.
errs = 0
total = 0
missing = 0
for fname,e,t,m in sorted(outputs):
    if not args.erroronly:
        # single parenthesized argument: same output as the print statement
        print("%6d\t%6d\t%s"%(e,t,fname))
    errs += e
    total += t
    missing += m

# the summary goes to stderr so that stdout carries only the per-file
# lines (or, with -e, only the error rate)
sys.stderr.write("errors    %8d\n"%errs)
sys.stderr.write("missing   %8d\n"%missing)
sys.stderr.write("total     %8d\n"%total)

# The rates are relative to the total ground truth length; without any
# ground truth characters they are undefined, so stop with a message
# rather than a ZeroDivisionError traceback.
if total==0:
    sys.stderr.write("no ground truth characters; cannot compute an error rate\n")
    sys.exit(1)

sys.stderr.write("err       %8.3f %%\n"%(errs*100.0/total))
# error rate with the errors attributed to missing output files taken out
sys.stderr.write("errnomiss %8.3f %%\n"%((errs-missing)*100.0/total))

if args.erroronly:
    print(errs*1.0/total)
